import keras
import keras.layers as kl
import keras.models as km
import keras.backend as kb
import tensorflow as tf


class Uint8ModelTestv34:
    """Builder for a small Keras image-upscaling network.

    ``network()`` returns a model that maps a 3-channel image of any spatial
    size to a 3-channel image ``scale`` times larger in height and width. It
    uses four stacked grouped-convolution blocks (``GConv10``), a shallow
    3x3 side branch, and a depth-to-space upsampler, and clamps the output to
    ``[0, 1]``.

    Args:
        scale: Integer upscaling factor (>= 2). It is used as the
            ``block_size`` of the depth-to-space layer, and the final
            convolution emits ``3 * scale**2`` channels so that the upsampled
            output has exactly 3 channels.

    Raises:
        ValueError: If ``scale`` is not an integer >= 2.
    """

    def __init__(self, scale=3):
        # tf.nn.depth_to_space requires an integer block_size of at least 2.
        if int(scale) != scale or scale < 2:
            raise ValueError(f"scale must be an integer >= 2, got {scale!r}")
        self.scale = int(scale)

        # Pixel-shuffle upsampler: (H, W, C * s*s) -> (H * s, W * s, C).
        # Previously hard-coded to 3, which silently ignored `scale`.
        self.Depth2Space = kl.Lambda(
            lambda x, block_size: tf.nn.depth_to_space(x, block_size),
            arguments={"block_size": self.scale},
        )

        # Template initializer. Every layer receives its own clone (see
        # _fresh_initializer), so it is never called twice for the same layer.
        self.initializer = keras.initializers.VarianceScaling(
            scale=0.1 * 2.0, mode="fan_in", distribution="truncated_normal"
        )

        # Clamp to [0, 1]. tf.clip_by_value is what Keras 2's backend clip
        # called internally. Keras 3's public keras.backend has no clip().
        self.Clamp = kl.Lambda(lambda x: tf.clip_by_value(x, 0.0, 1.0))

        # Split the channel axis into four contiguous groups. Any remainder
        # when the channel count is not divisible by 4 goes to the last group.
        self.Slice_4_0 = kl.Lambda(lambda x: x[..., : x.shape[-1] // 4])
        self.Slice_4_1 = kl.Lambda(lambda x: x[..., x.shape[-1] // 4 : 2 * x.shape[-1] // 4])
        self.Slice_4_2 = kl.Lambda(lambda x: x[..., 2 * x.shape[-1] // 4 : 3 * x.shape[-1] // 4])
        self.Slice_4_3 = kl.Lambda(lambda x: x[..., 3 * x.shape[-1] // 4 :])

    def _fresh_initializer(self):
        """Return a new, independent copy of ``self.initializer``.

        In Keras >= 2.10 and in Keras 3, an unseeded initializer instance
        returns the same values every time it is called for a given shape.
        Sharing one instance across layers would therefore give same-shaped
        layers identical initial kernels. Rebuilding from the config gives
        each layer its own instance and, when unseeded, its own seed.
        Non-Keras initializers (strings, plain callables) are returned as-is.
        """
        init = self.initializer
        if hasattr(init, "get_config") and hasattr(type(init), "from_config"):
            return type(init).from_config(init.get_config())
        return init

    def _conv3x3_relu(self, x):
        """Apply one 8-filter 3x3 'same' convolution followed by ReLU."""
        x = kl.Conv2D(8, (3, 3), padding="same", kernel_initializer=self._fresh_initializer())(x)
        return kl.ReLU()(x)

    def GConv10(self, x_in, slice=True, out_filter=32):
        """Grouped-convolution block: four 3x3 conv branches fused by a 1x1 conv.

        Args:
            x_in: Input tensor (channels-last, as indexed by the slice layers).
            slice: If True, split ``x_in``'s channels into four groups and
                give one group to each branch. If False, every branch sees
                the full ``x_in``.
            out_filter: Number of filters of the fusing 1x1 convolution.

        Returns:
            Tensor with ``out_filter`` channels and the spatial size of
            ``x_in``. No activation is applied after the 1x1 convolution.
        """
        if slice:
            branches = [
                self.Slice_4_0(x_in),
                self.Slice_4_1(x_in),
                self.Slice_4_2(x_in),
                self.Slice_4_3(x_in),
            ]
        else:
            branches = [x_in] * 4

        # Layers are created in branch order, matching the original naming
        # order (conv, relu, conv, relu, ...).
        branches = [self._conv3x3_relu(b) for b in branches]

        x_out = kl.Concatenate()(branches)
        return kl.Conv2D(
            out_filter, (1, 1), padding="same", kernel_initializer=self._fresh_initializer()
        )(x_out)

    def network(self):
        """Build and return the (uncompiled) upscaling ``keras.Model``.

        The input ``"lr_image"`` has shape ``(None, None, 3)`` (any height
        and width). The output has shape ``(H * scale, W * scale, 3)`` and is
        clamped to ``[0, 1]``. No input normalization is applied, so inputs
        are presumably already scaled to ``[0, 1]``. TODO: confirm with the
        data pipeline.
        """
        # The original annotated this input with "128", presumably the
        # training patch size. TODO: confirm.
        lr_image_input = kl.Input(shape=(None, None, 3), name="lr_image")

        # Shallow 3x3 side branch that bypasses the GConv10 stack.
        x_f = kl.Conv2D(
            16, (3, 3), strides=1, padding="same", kernel_initializer=self._fresh_initializer()
        )(lr_image_input)

        # The first block sees the 3-channel input. The slice layers would
        # split that into empty groups (3 // 4 == 0), so every branch gets
        # the full input instead. The next three blocks split 32 channels
        # into 4 x 8.
        x_p = self.GConv10(lr_image_input, slice=False, out_filter=32)
        for _ in range(3):
            x_p = self.GConv10(x_p, slice=True, out_filter=32)

        # Fuse the deep path with the side branch.
        x_p = kl.Concatenate()([x_p, x_f])
        x_p = kl.Conv2D(
            32, (1, 1), strides=1, padding="same", kernel_initializer=self._fresh_initializer()
        )(x_p)
        x_p = kl.ReLU()(x_p)

        # Emit 3 * scale**2 channels so depth-to-space yields 3 output channels.
        x_d = kl.Conv2D(
            3 * self.scale ** 2,
            (3, 3),
            strides=1,
            padding="same",
            kernel_initializer=self._fresh_initializer(),
        )(x_p)
        hr_img_reconstructed = self.Clamp(self.Depth2Space(x_d))

        return km.Model(lr_image_input, hr_img_reconstructed, name="Uint8ModelTestv34")
